{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wJcYs_ERTnnI"
      },
      "source": [
        "##### Copyright 2021 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "HMUDt0CiUJk9"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "77z2OchJTk0l"
      },
      "source": [
        "# Migrate checkpoint saving\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/guide/migrate/checkpoint_saver\">\n",
        "    <img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />\n",
        "    View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/guide/migrate/checkpoint_saver.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />\n",
        "    Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/guide/migrate/checkpoint_saver.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />\n",
        "    View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/guide/migrate/checkpoint_saver.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hIo_p2FWFIRx"
      },
      "source": [
        "Continually saving the \"best\" model or model weights/parameters has several benefits: you can track training progress, and you can load the model from any of the saved states, for example to resume training or to compare checkpoints.\n",
        "\n",
        "In TensorFlow 1, to configure checkpoint saving during training/validation with the `tf.estimator.Estimator` APIs, you specify a schedule in `tf.estimator.RunConfig` or use `tf.estimator.CheckpointSaverHook`. This guide demonstrates how to migrate from this workflow to TensorFlow 2 Keras APIs.\n",
        "\n",
        "In TensorFlow 2, you can configure `tf.keras.callbacks.ModelCheckpoint` in a number of ways:\n",
        "\n",
        "- Save the \"best\" version according to a monitored metric by setting `save_best_only=True`, where the `monitor` argument can be, for example, `'loss'`, `'val_loss'`, `'accuracy'`, or `'val_accuracy'`.\n",
        "- Save continually at a certain frequency (using the `save_freq` argument).\n",
        "- Save the weights/parameters only instead of the whole model by setting `save_weights_only` to `True`.\n",
        "\n",
        "For more details, refer to the `tf.keras.callbacks.ModelCheckpoint` API docs and the *Save checkpoints during training* section in the [Save and load models](../../tutorials/keras/save_and_load.ipynb) tutorial. Learn more about the Checkpoint format in the *TF Checkpoint format* section in the [Save and load Keras models](https://www.tensorflow.org/guide/keras/save_and_serialize) guide. In addition, to add fault tolerance, you can use `tf.keras.callbacks.BackupAndRestore` or `tf.train.Checkpoint` for manual checkpointing. Learn more in the [Fault tolerance migration guide](fault_tolerance.ipynb).\n",
        "\n",
        "Keras [callbacks](https://www.tensorflow.org/guide/keras/custom_callback) are objects that are called at different points during training/evaluation/prediction in the built-in Keras `Model.fit`/`Model.evaluate`/`Model.predict` APIs. Learn more in the _Next steps_ section at the end of the guide."
      ]
    },
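    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b7c1e2d4f901"
      },
      "source": [
        "To make the `save_best_only=True` option more concrete, the following plain-Python sketch (it does not use TensorFlow) mimics the comparison that decides whether a checkpoint is written. It is a simplification under stated assumptions, not the Keras implementation: it assumes that, roughly as with `mode='auto'` in Keras 2.x, a monitored metric whose name contains `'acc'` is treated as higher-is-better and any other metric as lower-is-better, and that a checkpoint is saved only when the metric strictly improves on the best value seen so far. The `epochs_saved` helper and the metric values are hypothetical."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c3a9d5e7f210"
      },
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "\n",
        "def epochs_saved(history, monitor):\n",
        "  # Hypothetical helper: returns the 1-based epochs at which a\n",
        "  # `save_best_only=True` checkpoint would be written for `history`.\n",
        "  # Assumption: metrics containing 'acc' are higher-is-better and all others\n",
        "  # are lower-is-better (roughly what `mode='auto'` does in Keras 2.x).\n",
        "  higher_is_better = 'acc' in monitor\n",
        "  best = -math.inf if higher_is_better else math.inf\n",
        "  saved = []\n",
        "  for epoch, value in enumerate(history, start=1):\n",
        "    # Only a strict improvement over the best value so far triggers a save.\n",
        "    improved = value > best if higher_is_better else value < best\n",
        "    if improved:\n",
        "      best = value\n",
        "      saved.append(epoch)\n",
        "  return saved\n",
        "\n",
        "\n",
        "# Made-up per-epoch validation metrics, for illustration only.\n",
        "val_loss = [0.30, 0.25, 0.27, 0.22, 0.22]\n",
        "val_accuracy = [0.91, 0.93, 0.92, 0.94, 0.94]\n",
        "\n",
        "print(epochs_saved(val_loss, 'val_loss'))          # [1, 2, 4]\n",
        "print(epochs_saved(val_accuracy, 'val_accuracy'))  # [1, 2, 4]"
      ]
    },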
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f55c103999de"
      },
      "source": [
        "## Setup\n",
        "\n",
        "Start with imports and a simple dataset for demonstration purposes:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "X74yjOb-e18w"
      },
      "outputs": [],
      "source": [
        "import tensorflow.compat.v1 as tf1\n",
        "import tensorflow as tf\n",
        "import numpy as np\n",
        "import tempfile"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2r8r4d8FfMny"
      },
      "outputs": [],
      "source": [
        "mnist = tf.keras.datasets.mnist\n",
        "\n",
        "(x_train, y_train),(x_test, y_test) = mnist.load_data()\n",
        "x_train, x_test = x_train / 255.0, x_test / 255.0"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wrqBkG4RFLP_"
      },
      "source": [
        "## TensorFlow 1: Save checkpoints with tf.estimator APIs\n",
        "\n",
        "This TensorFlow 1 example shows how to configure `tf.estimator.RunConfig` to save checkpoints at every step during training/evaluation with the `tf.estimator.Estimator` APIs:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "upA8nuf3FEq5"
      },
      "outputs": [],
      "source": [
        "feature_columns = [tf1.feature_column.numeric_column(\"x\", shape=[28, 28])]\n",
        "\n",
        "# Save summaries and checkpoints at every training step.\n",
        "config = tf1.estimator.RunConfig(save_summary_steps=1,\n",
        "                                 save_checkpoints_steps=1)\n",
        "\n",
        "path = tempfile.mkdtemp()\n",
        "\n",
        "classifier = tf1.estimator.DNNClassifier(\n",
        "    feature_columns=feature_columns,\n",
        "    hidden_units=[256, 32],\n",
        "    optimizer=tf1.train.AdamOptimizer(0.001),\n",
        "    n_classes=10,\n",
        "    dropout=0.2,\n",
        "    model_dir=path,\n",
        "    config=config\n",
        ")\n",
        "\n",
        "train_input_fn = tf1.estimator.inputs.numpy_input_fn(\n",
        "    x={\"x\": x_train},\n",
        "    y=y_train.astype(np.int32),\n",
        "    num_epochs=10,\n",
        "    batch_size=50,\n",
        "    shuffle=True,\n",
        ")\n",
        "\n",
        "test_input_fn = tf1.estimator.inputs.numpy_input_fn(\n",
        "    x={\"x\": x_test},\n",
        "    y=y_test.astype(np.int32),\n",
        "    num_epochs=10,\n",
        "    shuffle=False\n",
        ")\n",
        "\n",
        "train_spec = tf1.estimator.TrainSpec(input_fn=train_input_fn, max_steps=10)\n",
        "eval_spec = tf1.estimator.EvalSpec(input_fn=test_input_fn,\n",
        "                                   steps=10,\n",
        "                                   throttle_secs=0)\n",
        "\n",
        "tf1.estimator.train_and_evaluate(estimator=classifier,\n",
        "                                train_spec=train_spec,\n",
        "                                eval_spec=eval_spec)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3u96G4MtRVqU"
      },
      "outputs": [],
      "source": [
        "%ls {classifier.model_dir}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QvE_uxDJFUX-"
      },
      "source": [
        "## TensorFlow 2: Save checkpoints with a Keras callback for Model.fit\n",
        "\n",
        "In TensorFlow 2, when you use the built-in Keras `Model.fit` for training (optionally with validation data), you can configure `tf.keras.callbacks.ModelCheckpoint` and then pass it to the `callbacks` parameter of `Model.fit`. (Learn more in the API docs and the *Using callbacks* section in the [Training and evaluation with the built-in methods](https://www.tensorflow.org/guide/keras/train_and_evaluate) guide.)\n",
        "\n",
        "In the example below, you will use a `tf.keras.callbacks.ModelCheckpoint` callback to store checkpoints in a temporary directory:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9FLBhT2BFX2H"
      },
      "outputs": [],
      "source": [
        "def create_model():\n",
        "  return tf.keras.models.Sequential([\n",
        "    tf.keras.layers.Flatten(input_shape=(28, 28)),\n",
        "    tf.keras.layers.Dense(512, activation='relu'),\n",
        "    tf.keras.layers.Dropout(0.2),\n",
        "    tf.keras.layers.Dense(10, activation='softmax')\n",
        "  ])\n",
        "\n",
        "model = create_model()\n",
        "model.compile(optimizer='adam',\n",
        "              loss='sparse_categorical_crossentropy',\n",
        "              metrics=['accuracy'],\n",
        "              steps_per_execution=10)\n",
        "\n",
        "log_dir = tempfile.mkdtemp()\n",
        "\n",
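        "# With the defaults (`save_freq='epoch'`, `save_best_only=False`), this saves the\n",
        "# full model to `filepath` at the end of every epoch, overwriting the previous save.\n",
        "model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(\n",
        "    filepath=log_dir)\n",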
        "\n",
        "model.fit(x=x_train,\n",
        "          y=y_train,\n",
        "          epochs=10,\n",
        "          validation_data=(x_test, y_test),\n",
        "          callbacks=[model_checkpoint_callback])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SROSmhyyLBA-"
      },
      "outputs": [],
      "source": [
        "%ls {model_checkpoint_callback.filepath}"
      ]
    },
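    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d5e8f1a2b347"
      },
      "source": [
        "The callback above uses a fixed directory as `filepath`, so with the default `save_freq='epoch'` each epoch's save replaces the previous one. `ModelCheckpoint` also accepts a `filepath` containing named formatting placeholders, such as `{epoch:02d}` or `{val_loss:.2f}`, which lets you keep one checkpoint per epoch. The plain-Python sketch below shows how such a template resolves, assuming (as in Keras 2.x) that the callback fills it with `str.format`, passing a 1-based `epoch` and the entries of that epoch's `logs` dictionary. The template and the log values are hypothetical."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "e9f2a4b6c158"
      },
      "outputs": [],
      "source": [
        "# Hypothetical template: `epoch` and `val_loss` are filled in at save time.\n",
        "template = 'checkpoints/model.{epoch:02d}-{val_loss:.2f}'\n",
        "\n",
        "# Made-up `logs` dictionaries for two epochs, for illustration only.\n",
        "epoch_logs = [\n",
        "    {'loss': 0.41, 'val_loss': 0.2934},\n",
        "    {'loss': 0.29, 'val_loss': 0.2511},\n",
        "]\n",
        "\n",
        "paths = []\n",
        "for epoch_index, logs in enumerate(epoch_logs):\n",
        "  # Assumption (Keras 2.x behavior): callbacks see a 0-based epoch index,\n",
        "  # and the template receives epoch_index + 1. Unused keys such as 'loss' are ignored.\n",
        "  paths.append(template.format(epoch=epoch_index + 1, **logs))\n",
        "\n",
        "print(paths)  # ['checkpoints/model.01-0.29', 'checkpoints/model.02-0.25']"
      ]
    },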
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rQUS8nO9FZlH"
      },
      "source": [
        "## Next steps\n",
        "\n",
        "Learn more about checkpointing in:\n",
        "\n",
        "- API docs: `tf.keras.callbacks.ModelCheckpoint`\n",
        "- Tutorial: [Save and load models](../../tutorials/keras/save_and_load.ipynb) (the *Save checkpoints during training* section)\n",
        "- Guide: [Save and load Keras models](https://www.tensorflow.org/guide/keras/save_and_serialize) (the *TF Checkpoint format* section)\n",
        "\n",
        "Learn more about callbacks in:\n",
        "\n",
        "- API docs: `tf.keras.callbacks.Callback`\n",
        "- Guide: [Writing your own callbacks](https://www.tensorflow.org/guide/keras/custom_callback)\n",
        "- Guide: [Training and evaluation with the built-in methods](https://www.tensorflow.org/guide/keras/train_and_evaluate) (the *Using callbacks* section)\n",
        "\n",
        "You may also find the following migration-related resources useful:\n",
        "\n",
        "- The [Fault tolerance migration guide](fault_tolerance.ipynb): `tf.keras.callbacks.BackupAndRestore` for `Model.fit`, or `tf.train.Checkpoint` and `tf.train.CheckpointManager` APIs for a custom training loop\n",
        "- The [Early stopping migration guide](early_stopping.ipynb): `tf.keras.callbacks.EarlyStopping` is a built-in early stopping callback\n",
        "- The [TensorBoard migration guide](tensorboard.ipynb): TensorBoard enables tracking and displaying metrics\n",
        "- The [LoggingTensorHook and StopAtStepHook to Keras callbacks migration guide](logging_stop_hook.ipynb)\n",
        "- The [SessionRunHook to Keras callbacks guide](sessionrunhook_callback.ipynb)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "checkpoint_saver.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
